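Coin toss results

Compares several approximate posteriors for the coin's probability parameter against the exact conjugate Beta posterior. The approximations are AJAX VI, a KDE of BlackJAX `rmh` MCMC samples, Laplace and Laplax.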

In [1]:
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import laplax
import pandas as pd
import pickle
tfd = tfp.distributions
import plotly
plotly.offline.init_notebook_mode()
In [2]:
def get_data(file_name):
    """Load a pickled coin-toss dataset and its Beta prior parameters.

    Parameters
    ----------
    file_name : str
        Path to a pickle holding a dict with a ``'data'`` entry and a
        ``'prior'`` entry that itself has ``'alpha'`` and ``'beta'`` keys.

    Returns
    -------
    tuple
        ``(samples, alpha_prior, beta_prior)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at files this project produced.
    with open(file_name, 'rb') as handle:
        payload = pickle.load(handle)
    observations = payload['data']
    prior = payload['prior']
    return observations, prior['alpha'], prior['beta']
def get_likelihood(params, aux=None):
    """Return the per-toss Bernoulli likelihood for the coin model.

    Parameters
    ----------
    params : dict
        Must contain ``'p_of_h'``, which is used as the Bernoulli success
        probability (presumably the probability of heads -- TODO confirm).
    aux : optional
        Ignored here; presumably kept so the signature matches what the
        inference code expects (TODO confirm).

    Returns
    -------
    tfd.Bernoulli
    """
    success_prob = params['p_of_h']
    return tfd.Bernoulli(probs=success_prob)
In [3]:
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only load result
    # files this project produced itself.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _laplace_family_pdf(path, x):
    """Evaluate, on grid ``x``, the density of a pickled Laplace-style posterior.

    The file is read with ``pd.read_pickle`` and must be a mapping with
    ``'model'``, ``'params'`` and ``'data'`` entries. ``model.apply(params, data)``
    must return an object whose ``log_prob`` accepts ``{'p_of_h': x}`` and a
    ``sample_shape`` keyword.
    """
    result = pd.read_pickle(path)
    posterior = result['model'].apply(result['params'], result['data'])
    return jnp.exp(posterior.log_prob({'p_of_h': x}, sample_shape=(len(x), )))


def plot_coin_toss_results(varient='', burn_in=300):
    """Plot the coin-toss data and compare posterior estimates on one figure.

    Shows a histogram of the observed tosses, prints the prior parameters and
    the counts of 1s and 0s, then draws one plotly line per estimate:
    the exact Beta posterior, AJAX VI, a KDE of BlackJAX MCMC draws, Laplace
    and Laplax.

    Parameters
    ----------
    varient : str or int, optional
        Suffix (converted with ``str``) appended to every data/result file
        name, selecting which experiment variant to load. Default ``''``.
    burn_in : int, optional
        Number of leading MCMC draws discarded before the KDE. Default 300.

    Returns
    -------
    None
    """
    import numpy as np  # used only to hand plain arrays to pandas

    varient = str(varient)
    samples, alpha_prior, beta_prior = get_data("../../data/coin_toss/coin_toss" + varient)

    # Raw data: histogram of the observed outcomes.
    _, ax = plt.subplots()
    ax.hist(samples)
    ax.set_ylabel("frequency")
    ax.set_title("Given Data")
    plt.show()

    # Evaluation grid, kept strictly inside (0, 1) away from the endpoints.
    x = jnp.linspace(0.01, 0.99, 100)
    n_ones = jnp.sum(samples == 1).astype('float32')
    n_zeros = jnp.sum(samples == 0).astype('float32')
    print(alpha_prior, beta_prior, n_ones, n_zeros)

    # label -> pdf evaluated on x; insertion order fixes the legend order.
    pdfs = {}

    # Exact reference: a Beta(a, b) prior with Bernoulli observations yields a
    # Beta(a + #ones, b + #zeros) posterior.
    pdfs["True Posterior"] = tfd.Beta(alpha_prior + n_ones, beta_prior + n_zeros).prob(x)

    variational = _load_pickle('results_data/coin_toss_VI_Ajax_result' + varient)
    pdfs["AJAX VI"] = jnp.exp(variational.log_prob({"theta": x}))

    black_samples = _load_pickle('results_data/MCMC_BlackJAX' + varient)
    # Drop the first `burn_in` draws, then smooth the remaining chain with a KDE.
    chain = black_samples.position['x'][burn_in:, 0]
    pdfs["Blackjax rmh estimate"] = gaussian_kde(chain)(x)

    pdfs["Laplace"] = _laplace_family_pdf('results_data/laplace_coin_toss' + varient, x)
    pdfs["Laplax"] = _laplace_family_pdf('results_data/laplax_coin_toss' + varient, x)

    # Long format (one row per (label, theta) pair), as plotly express expects.
    # Building each estimate's frame separately fails loudly if a pdf does not
    # have one value per grid point.
    theta = np.asarray(x)
    df = pd.concat(
        [
            pd.DataFrame({"theta": theta, "PDF": np.ravel(pdf), "label": label})
            for label, pdf in pdfs.items()
        ],
        ignore_index=True,
    )

    fig = px.line(df, x="theta", y="PDF", color="label",
                  title=f"Coin toss posterior  prior=({alpha_prior},{beta_prior})")
    fig.show()
In [4]:
# Default run: data and result files carry no name suffix.
plot_coin_toss_results(varient='')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
10 15 8.0 2.0
In [5]:
# Variant 1: loads the data and result files whose names end in "1".
plot_coin_toss_results(varient=1)
5 1 8.0 2.0
In [6]:
# Export this notebook to HTML. nbconvert converts the .ipynb file on disk,
# so save the notebook first or the HTML may not contain the latest outputs.
!jupyter nbconvert --to HTML coin_toss_results.ipynb
[NbConvertApp] WARNING | Config option `kernel_spec_manager_class` not recognized by `NbConvertApp`.
[NbConvertApp] Converting notebook coin_toss_results.ipynb to HTML
[NbConvertApp] Writing 4315881 bytes to coin_toss_results.html